from dowel import logger, tabular
from torch.optim import Optimizer

import torch


class SHARPOptimizer(Optimizer):
    """Normalized-momentum optimizer with a Hessian-vector-product correction.

    On every call to :meth:`step` (``t`` is the zero-based call counter):

    * ``eta = a / (1 + t) ** (2 / 3)`` and, for ``t >= 1``,
      ``alpha = b / (1 + t) ** (2 / 3)`` (``alpha`` is ``None`` at ``t == 0``).
    * For ``t >= 1`` the model is moved to a uniformly random point on the
      segment between the previous and the current iterate, the closure is
      evaluated there, and a Hessian-vector product along the last
      displacement is formed (see :meth:`compute_hvp`).
    * The per-parameter momentum buffer becomes
      ``(1 - alpha) * (buffer + hvp) + grad`` (just ``grad`` on first use).
    * Every parameter is set to ``current_point - eta / ||buffer|| * buffer``
      where the norm is taken over the buffers of all parameters.

    Parameters whose ``.grad`` is ``None`` are left at their current value.
    """

    def __init__(self, params, a, b):
        """Create the optimizer.

        Args:
            params: iterable of parameters or parameter-group dicts, as
                accepted by ``torch.optim.Optimizer``.
            a: numerator of the step-size schedule ``eta``.
            b: numerator of the momentum schedule ``alpha``.
        """
        self.power = 1.0 / 3.0
        self.a = a
        self.b = b
        # eta / alpha are overwritten at the start of every step(); the
        # values below are only placeholders until the first call.
        self.eta = 2.5e-3
        self.alpha = 0.25
        # Starts at -1 so that the first step() runs with iteration == 0.
        self.iteration = -1
        defaults = dict()
        self.sqr_grads_norms = 0
        self.last_grad_norm = 0

        super(SHARPOptimizer, self).__init__(params, defaults)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['current_point'] = torch.zeros_like(p)
                state['last_point'] = torch.zeros_like(p)

    def update_the_displacement(self, total_norm):
        """Move every parameter along its normalized momentum buffer.

        Sets ``p = current_point - eta / total_norm * momentum_buffer`` and
        then records ``current_point`` as the new ``last_point``.

        Args:
            total_norm: norm used to normalize the buffers. A zero norm
                results in a zero-length step instead of a division error.

        Parameters without a momentum buffer (they never had a gradient) are
        restored to ``current_point`` so that a previous call to
        :meth:`update_model_to_random_line_point` does not leave them
        displaced.
        """
        # A zero norm means there is no direction to follow: take no step.
        step_size = self.eta / total_norm if total_norm else 0.0

        with torch.no_grad():
            for group in self.param_groups:
                for p in group['params']:
                    state = self.state[p]
                    last_point = state['last_point']
                    current_point = state['current_point']
                    buf = state.get('momentum_buffer')

                    if buf is None:
                        p.copy_(current_point)
                    else:
                        p.copy_(current_point - step_size * buf)
                    last_point.copy_(current_point)

    def update_model_to_random_line_point(self, b):
        """Set every parameter to ``b * current_point + (1 - b) * last_point``.

        Args:
            b: interpolation weight given to ``current_point``.
        """
        with torch.no_grad():
            for group in self.param_groups:
                for p in group['params']:
                    state = self.state[p]
                    last_point = state['last_point']
                    current_point = state['current_point']
                    p.copy_(b * current_point + (1 - b) * last_point)

    def save_current_point(self, ):
        """Copy the present value of every parameter into ``current_point``."""
        with torch.no_grad():
            for group in self.param_groups:
                for p in group['params']:
                    self.state[p]['current_point'].copy_(p)

    def generate_uniform_number(self, ):
        """Return one float drawn uniformly from ``[0, 1)`` via ``torch.rand``."""
        return torch.rand(1).item()

    def inner_product_of_list_var(self, array1_, array2_):
        """Return the inner product of two lists of tensors.

        Args:
            array1_: list of tensors; its length drives the iteration.
            array2_: list of tensors, indexed in lockstep with ``array1_``.

        Returns:
            The sum over ``i`` of ``torch.sum(array1_[i] * array2_[i])``
            (the plain integer ``0`` when ``array1_`` is empty).
        """
        total = 0
        for idx, left in enumerate(array1_):
            total += torch.sum(left * array2_[idx])
        return total

    def compute_hvp(self, grads, g_ll, params, vector):
        """Return ``<g_ll, vector> * grads[i] + (d grads / d params) . vector``.

        Args:
            grads: list of gradient tensors that are still attached to the
                autograd graph, since they are differentiated again here.
            g_ll: list of tensors whose inner product with ``vector`` scales
                ``grads`` in the first term.
            params: tensors with respect to which ``grads`` is differentiated.
            vector: list of tensors used as ``grad_outputs``.

        Returns:
            A list with one tensor per entry of ``grads``.
        """
        # First term: scalar <g_ll, vector> multiplying each gradient.
        inner_product = self.inner_product_of_list_var(g_ll, vector)
        # Second term: vector-Jacobian product of the gradients.
        second_term = torch.autograd.grad(outputs=grads, inputs=params,
                                          grad_outputs=vector,
                                          retain_graph=True)
        return [inner_product * grad + second_term[idx]
                for idx, grad in enumerate(grads)]

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): Re-evaluates the model so that
                ``p.grad`` is repopulated with graph-attached gradients, and
                returns the list of tensors passed to :meth:`compute_hvp` as
                ``g_ll``. Required from the second call onwards.

        Returns:
            Always ``None``; the closure's return value is not a loss.

        Raises:
            TypeError: if ``closure`` is missing on any call after the first.
            RuntimeError: if the closure leaves a previously populated
                ``p.grad`` empty.
        """
        next_iteration = self.iteration + 1
        # Validate before mutating any state so a failed call can be retried.
        if next_iteration >= 1 and closure is None:
            raise TypeError(
                "SHARPOptimizer.step() requires a closure on every call "
                "after the first one")

        loss = None
        self.iteration = next_iteration
        g_square_norm = 0
        grad_square_norm = 0

        self.save_current_point()
        # update eta
        self.eta = self.a / (1 + self.iteration) ** (2. / 3)

        # update alpha
        if self.iteration == 0:
            self.alpha = None
        else:
            self.alpha = self.b / (1 + self.iteration) ** (2. / 3)

        # Snapshot gradients and displacements of every group before any
        # closure call: p.grad is read again after the closure, so the closure
        # is expected to overwrite it.
        snapshots = []
        for group in self.param_groups:
            active = [p for p in group['params'] if p.grad is not None]
            vector = [self.state[p]['current_point'] - self.state[p]['last_point']
                      for p in active]
            grads = [p.grad.clone() for p in active]
            snapshots.append((active, vector, grads))

        for active, vector, grads in snapshots:
            hvp = None
            if self.iteration >= 1 and active:
                # compute hessian vector
                self.update_model_to_random_line_point(self.generate_uniform_number())
                with torch.enable_grad():
                    g_ll = closure()

                modified_grads = [p.grad for p in active]
                if any(grad is None for grad in modified_grads):
                    raise RuntimeError(
                        "closure did not populate .grad for a parameter that "
                        "had a gradient before it was called")
                # NOTE(review): g_ll is indexed in lockstep with this group's
                # parameters that have gradients; confirm the closure returns
                # it in that order.
                hvp = self.compute_hvp(modified_grads, g_ll, active, vector)

            with torch.no_grad():
                for idx, (p, d_p) in enumerate(zip(active, grads)):
                    state = self.state[p]
                    grad_square_norm += d_p.norm(2).item() ** 2

                    restart = (hvp is None or self.alpha == 1
                               or 'momentum_buffer' not in state)
                    if restart:
                        buf = state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = state['momentum_buffer']
                        buf.add_(hvp[idx]).mul_(1 - self.alpha).add_(d_p, alpha=1)

                    g_square_norm += buf.norm(2).item() ** 2

        g_norm = g_square_norm ** (1. / 2)
        grad_norm = grad_square_norm ** (1. / 2)

        print("norm of g", g_norm)
        print("norm of gradient", grad_norm)
        print("alpha", self.alpha)
        print("eta: ", self.eta)
        with tabular.prefix("SHARP" + '/'):
            tabular.record('norm of g', g_norm)
            tabular.record('norm of gradient', grad_norm)
            tabular.record('eta', self.eta)
            tabular.record('alpha', self.alpha)
            logger.log(tabular)
        self.update_the_displacement(g_norm)

        return loss
